import tensorflow as tf
import numpy as np
import re
from baselines.acktr.kfac_utils import gmatmul, detectMinVal, factorReshape
from functools import reduce

KFAC_OPS = ['MatMul', 'Conv2D', 'BiasAdd']
KFAC_DEBUG = True
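
# Example usage (a minimal sketch; `loss` and `loss_sampled` are assumed to be
# scalar tensors built elsewhere, with `loss_sampled` computed from the model's
# own sampled targets, as the Fisher estimate requires):
#
#   optim = KfacOptimizer(learning_rate=0.03, momentum=0.9, kfac_update=2)
#   update_op, q_runner = optim.minimize(loss, loss_sampled)
#
# q_runner is only populated when async_=True; in that case start it with a
# tf.train.Coordinator alongside the training loop.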


class KfacOptimizer():

  # NOTE: 'async' is a reserved word in Python 3.7+, hence the trailing
  # underscore in async_.
  def __init__(self, learning_rate=0.01, momentum=0.9, clip_kl=0.01,
               kfac_update=2, stats_accum_iter=60, full_stats_init=False,
               cold_iter=100, cold_lr=None, async_=False, async_stats=False,
               epsilon=1e-2, stats_decay=0.95, blockdiag_bias=False,
               channel_fac=False, factored_damping=False, approxT2=False,
               use_float64=False, weight_decay_dict=None, max_grad_norm=0.5):
    self.max_grad_norm = max_grad_norm
    self._lr = learning_rate
    self._momentum = momentum
    self._clip_kl = clip_kl
    self._channel_fac = channel_fac
    self._kfac_update = kfac_update
    self._async = async_
    self._async_stats = async_stats
    self._epsilon = epsilon
    self._stats_decay = stats_decay
    self._blockdiag_bias = blockdiag_bias
    self._approxT2 = approxT2
    self._use_float64 = use_float64
    self._factored_damping = factored_damping
    self._cold_iter = cold_iter
    if cold_lr is None:
      # good heuristic
      self._cold_lr = self._lr  # * 3.
    else:
      self._cold_lr = cold_lr
    self._stats_accum_iter = stats_accum_iter
    # weight_decay_dict defaults to None to avoid a shared mutable default
    self._weight_decay_dict = weight_decay_dict if weight_decay_dict is not None else {}
    self._diag_init_coeff = 0.
    self._full_stats_init = full_stats_init
    if not self._full_stats_init:
      self._stats_accum_iter = self._cold_iter

    self.sgd_step = tf.Variable(0, name='KFAC/sgd_step', trainable=False)
    self.global_step = tf.Variable(
      0, name='KFAC/global_step', trainable=False)
    self.cold_step = tf.Variable(0, name='KFAC/cold_step', trainable=False)
    self.factor_step = tf.Variable(
      0, name='KFAC/factor_step', trainable=False)
    self.stats_step = tf.Variable(
      0, name='KFAC/stats_step', trainable=False)
    self.vFv = tf.Variable(0., name='KFAC/vFv', trainable=False)

    self.factors = {}
    self.param_vars = []
    self.stats = {}
    self.stats_eigen = {}

  def getFactors(self, g, varlist):
    graph = tf.get_default_graph()
    factorTensors = {}
    opTypes = []

    def searchFactors(gradient, graph):
      # hard-coded search strategy
      bpropOp = gradient.op
      bpropOp_name = bpropOp.name

      bTensors = []
      fTensors = []

      # combine additive gradients, assuming they are the same op type and
      # independent
      if 'AddN' in bpropOp_name:
        factors = []
        for g in gradient.op.inputs:
          factors.append(searchFactors(g, graph))
        op_names = [item['opName'] for item in factors]
        # TO-DO: need to check all the attributes of the ops as well
        if KFAC_DEBUG:
          print(gradient.name)
          print(op_names)
          print(len(np.unique(op_names)))
        assert len(np.unique(op_names)) == 1, \
          gradient.name + ' is shared among different computation OPs'

        bTensors = reduce(lambda x, y: x + y,
                          [item['bpropFactors'] for item in factors])
        if len(factors[0]['fpropFactors']) > 0:
          fTensors = reduce(
            lambda x, y: x + y, [item['fpropFactors'] for item in factors])
        fpropOp_name = op_names[0]
        fpropOp = factors[0]['op']
      else:
        fpropOp_name = re.search(
          'gradientsSampled(_[0-9]+|)/(.+?)_grad', bpropOp_name).group(2)
        fpropOp = graph.get_operation_by_name(fpropOp_name)
        if fpropOp.op_def.name in KFAC_OPS:
          # Known OPs
          ###
          bTensor = [
            i for i in bpropOp.inputs if 'gradientsSampled' in i.name][-1]
          bTensorShape = fpropOp.outputs[0].get_shape()
          if bTensor.get_shape()[0].value is None:
            bTensor.set_shape(bTensorShape)
          bTensors.append(bTensor)
          ###
          if fpropOp.op_def.name == 'BiasAdd':
            fTensors = []
          else:
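            # NOTE: 'param' below is the loop variable from getFactors'
            # zip(g, varlist) loop enclosing this call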
            fTensors.append(
              [i for i in fpropOp.inputs if param.op.name not in i.name][0])
          fpropOp_name = fpropOp.op_def.name
        else:
          # unknown OPs, use block-diagonal approximation
          bInputsList = [i for i in bpropOp.inputs[0].op.inputs
                         if 'gradientsSampled' in i.name if 'Shape' not in i.name]
          if len(bInputsList) > 0:
            bTensor = bInputsList[0]
            bTensorShape = fpropOp.outputs[0].get_shape()
            if len(bTensor.get_shape()) > 0 and bTensor.get_shape()[0].value is None:
              bTensor.set_shape(bTensorShape)
            bTensors.append(bTensor)
          fpropOp_name = 'UNK-' + fpropOp.op_def.name
          opTypes.append(fpropOp_name)

      return {'opName': fpropOp_name, 'op': fpropOp, 'fpropFactors': fTensors, 'bpropFactors': bTensors}

    for t, param in zip(g, varlist):
      if KFAC_DEBUG:
        print(('get factor for '+param.name))
      factors = searchFactors(t, graph)
      factorTensors[param] = factors

    ########
    # check associated weights and bias for homogeneous coordinate representation
    # and check for redundant factors
    # TO-DO: there may be a bug in detecting the associated bias and weights for
    # forking layers, e.g. in inception models.
    for param in varlist:
      factorTensors[param]['assnWeights'] = None
      factorTensors[param]['assnBias'] = None
    for param in varlist:
      if factorTensors[param]['opName'] == 'BiasAdd':
        factorTensors[param]['assnWeights'] = None
        for item in varlist:
          if len(factorTensors[item]['bpropFactors']) > 0:
            if (set(factorTensors[item]['bpropFactors']) == set(factorTensors[param]['bpropFactors'])) and (len(factorTensors[item]['fpropFactors']) > 0):
              factorTensors[param]['assnWeights'] = item
              factorTensors[item]['assnBias'] = param
              factorTensors[param]['bpropFactors'] = factorTensors[
                item]['bpropFactors']

    ########

    ########
    # concatenate the additive gradients along the batch dimension, i.e.
    # assuming independence structure
    for key in ['fpropFactors', 'bpropFactors']:
      for i, param in enumerate(varlist):
        if len(factorTensors[param][key]) > 0:
          if (key + '_concat') not in factorTensors[param]:
            name_scope = factorTensors[param][key][0].name.split(':')[
              0]
            with tf.name_scope(name_scope):
              factorTensors[param][
                key + '_concat'] = tf.concat(factorTensors[param][key], 0)
        else:
          factorTensors[param][key + '_concat'] = None
        for j, param2 in enumerate(varlist[(i + 1):]):
          if (len(factorTensors[param][key]) > 0) and (set(factorTensors[param2][key]) == set(factorTensors[param][key])):
            factorTensors[param2][key] = factorTensors[param][key]
            factorTensors[param2][
              key + '_concat'] = factorTensors[param][key + '_concat']
    ########

    if KFAC_DEBUG:
      for param in varlist:
        print((param.name, factorTensors[param]))
    self.factors = factorTensors
    return factorTensors

  def getStats(self, factors, varlist):
    if len(self.stats) == 0:
      # initialize stats variables on CPU because eigen decomp is
      # computed on CPU
      with tf.device('/cpu:0'):
        tmpStatsCache = {}

        # search for tensor factors and
        # use block diag approx for the bias units
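        # each layer's block Fisher is approximated as a Kronecker product
        # F_l ~= A (x) B, where A = E[a a^T] is the covariance of the layer's
        # forward activations and B = E[d d^T] the covariance of the
        # backpropagated gradients; the variables created below hold the
        # running estimates of A and B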
        for var in varlist:
          fpropFactor = factors[var]['fpropFactors_concat']
          bpropFactor = factors[var]['bpropFactors_concat']
          opType = factors[var]['opName']
          if opType == 'Conv2D':
            Kh = var.get_shape()[0]
            Kw = var.get_shape()[1]
            C = fpropFactor.get_shape()[-1]

            Oh = bpropFactor.get_shape()[1]
            Ow = bpropFactor.get_shape()[2]
            if Oh == 1 and Ow == 1 and self._channel_fac:
              # factorization along the channels does not support
              # homogeneous coordinates
              var_assnBias = factors[var]['assnBias']
              if var_assnBias:
                factors[var]['assnBias'] = None
                factors[var_assnBias]['assnWeights'] = None
        ##

        for var in varlist:
          fpropFactor = factors[var]['fpropFactors_concat']
          bpropFactor = factors[var]['bpropFactors_concat']
          opType = factors[var]['opName']
          self.stats[var] = {'opName': opType,
                             'fprop_concat_stats': [],
                             'bprop_concat_stats': [],
                             'assnWeights': factors[var]['assnWeights'],
                             'assnBias': factors[var]['assnBias'],
                             }
          if fpropFactor is not None:
            if fpropFactor not in tmpStatsCache:
              if opType == 'Conv2D':
                Kh = var.get_shape()[0]
                Kw = var.get_shape()[1]
                C = fpropFactor.get_shape()[-1]

                Oh = bpropFactor.get_shape()[1]
                Ow = bpropFactor.get_shape()[2]
                if Oh == 1 and Ow == 1 and self._channel_fac:
                  # factorization along the channels:
                  # assume independence between input channels and spatial
                  # positions, giving a (Kh*Kw) x (Kh*Kw) covariance matrix
                  # and a C x C covariance matrix.
                  # this factorization does not support homogeneous
                  # coordinates, so assnBias is always None
                  fpropFactor2_size = Kh * Kw
                  slot_fpropFactor_stats2 = tf.Variable(tf.diag(tf.ones(
                    [fpropFactor2_size])) * self._diag_init_coeff, name='KFAC_STATS/' + fpropFactor.op.name, trainable=False)
                  self.stats[var]['fprop_concat_stats'].append(
                    slot_fpropFactor_stats2)

                  fpropFactor_size = C
                else:
                  # (Kh*Kw*C) x (Kh*Kw*C) covariance matrix
                  # assume BHWC data format
                  fpropFactor_size = Kh * Kw * C
              else:
                # D x D covariance matrix
                fpropFactor_size = fpropFactor.get_shape()[-1]

              # use homogeneous coordinate
              if not self._blockdiag_bias and self.stats[var]['assnBias']:
                fpropFactor_size += 1

              slot_fpropFactor_stats = tf.Variable(tf.diag(tf.ones(
                [fpropFactor_size])) * self._diag_init_coeff, name='KFAC_STATS/' + fpropFactor.op.name, trainable=False)
              self.stats[var]['fprop_concat_stats'].append(
                slot_fpropFactor_stats)
              if opType != 'Conv2D':
                tmpStatsCache[fpropFactor] = self.stats[
                  var]['fprop_concat_stats']
            else:
              self.stats[var][
                'fprop_concat_stats'] = tmpStatsCache[fpropFactor]

          if bpropFactor is not None:
            # no need to collect backward stats for bias vectors if
            # using homogeneous coordinates
            if not((not self._blockdiag_bias) and self.stats[var]['assnWeights']):
              if bpropFactor not in tmpStatsCache:
                slot_bpropFactor_stats = tf.Variable(tf.diag(tf.ones([bpropFactor.get_shape(
                )[-1]])) * self._diag_init_coeff, name='KFAC_STATS/' + bpropFactor.op.name, trainable=False)
                self.stats[var]['bprop_concat_stats'].append(
                  slot_bpropFactor_stats)
                tmpStatsCache[bpropFactor] = self.stats[
                  var]['bprop_concat_stats']
              else:
                self.stats[var][
                  'bprop_concat_stats'] = tmpStatsCache[bpropFactor]

    return self.stats

  def compute_and_apply_stats(self, loss_sampled, var_list=None):
    varlist = var_list
    if varlist is None:
      varlist = tf.trainable_variables()

    stats = self.compute_stats(loss_sampled, var_list=varlist)
    return self.apply_stats(stats)

  def compute_stats(self, loss_sampled, var_list=None):
    varlist = var_list
    if varlist is None:
      varlist = tf.trainable_variables()

    gs = tf.gradients(loss_sampled, varlist, name='gradientsSampled')
    self.gs = gs
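    # NOTE: the 'gradientsSampled' name scope above is what searchFactors()
    # pattern-matches on to recover each layer's forward op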
    factors = self.getFactors(gs, varlist)
    stats = self.getStats(factors, varlist)

    updateOps = []
    statsUpdates = {}
    statsUpdates_cache = {}
    for var in varlist:
      opType = factors[var]['opName']
      fops = factors[var]['op']
      fpropFactor = factors[var]['fpropFactors_concat']
      fpropStats_vars = stats[var]['fprop_concat_stats']
      bpropFactor = factors[var]['bpropFactors_concat']
      bpropStats_vars = stats[var]['bprop_concat_stats']
      SVD_factors = {}
      for stats_var in fpropStats_vars:
        stats_var_dim = int(stats_var.get_shape()[0])
        if stats_var not in statsUpdates_cache:
          old_fpropFactor = fpropFactor
          B = (tf.shape(fpropFactor)[0])  # batch size
          if opType == 'Conv2D':
            strides = fops.get_attr("strides")
            padding = fops.get_attr("padding")
            convkernel_size = var.get_shape()[0:3]

            KH = int(convkernel_size[0])
            KW = int(convkernel_size[1])
            C = int(convkernel_size[2])
            flatten_size = int(KH * KW * C)

            Oh = int(bpropFactor.get_shape()[1])
            Ow = int(bpropFactor.get_shape()[2])

            if Oh == 1 and Ow == 1 and self._channel_fac:
              # factorization along the channels
              # assume independence among input channels
              # factor = B x 1 x 1 x (KH xKW x C)
              # patches = B x Oh x Ow x (KH xKW x C)
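              # i.e. approximate each flattened patch by its top SVD mode,
              # a ~= sigma_1 * u_1 (x) v_1, so the (KH*KW) and C Kronecker
              # factors can be estimated from u and v separately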
              if len(SVD_factors) == 0:
                if KFAC_DEBUG:
                  print(('approx %s act factor with rank-1 SVD factors' % (var.name)))
                # find the closest rank-1 approximation to the feature map
                # (tf.batch_svd was folded into tf.svd in TF 1.0)
                S, U, V = tf.svd(tf.reshape(
                  fpropFactor, [-1, KH * KW, C]))
                # get rank-1 approx slices; S has shape [B, min(KH*KW, C)]
                sqrtS1 = tf.expand_dims(tf.sqrt(S[:, 0]), 1)
                patches_k = U[:, :, 0] * sqrtS1  # B x KH*KW
                full_factor_shape = fpropFactor.get_shape()
                patches_k.set_shape(
                  [full_factor_shape[0], KH * KW])
                patches_c = V[:, :, 0] * sqrtS1  # B x C
                patches_c.set_shape([full_factor_shape[0], C])
                SVD_factors[C] = patches_c
                SVD_factors[KH * KW] = patches_k
              fpropFactor = SVD_factors[stats_var_dim]

            else:
              # memory-inefficient implementation
              patches = tf.extract_image_patches(fpropFactor, ksizes=[1, convkernel_size[
                0], convkernel_size[1], 1], strides=strides, rates=[1, 1, 1, 1], padding=padding)

              if self._approxT2:
                if KFAC_DEBUG:
                  print(('approxT2 act fisher for %s' % (var.name)))
                # T^2 terms * 1/T^2, size: B x C
                fpropFactor = tf.reduce_mean(patches, [1, 2])
              else:
                # size: (B x Oh x Ow) x C
                fpropFactor = tf.reshape(
                  patches, [-1, flatten_size]) / Oh / Ow
          fpropFactor_size = int(fpropFactor.get_shape()[-1])
          if stats_var_dim == (fpropFactor_size + 1) and not self._blockdiag_bias:
            if opType == 'Conv2D' and not self._approxT2:
              # correct padding for numerical stability (we
              # divided out OhxOw from activations for T1 approx)
              fpropFactor = tf.concat([fpropFactor, tf.ones(
                [tf.shape(fpropFactor)[0], 1]) / Oh / Ow], 1)
            else:
              # use homogeneous coordinates
              fpropFactor = tf.concat(
                [fpropFactor, tf.ones([tf.shape(fpropFactor)[0], 1])], 1)

          # average over the number of data points in a batch
          # divided by B
          cov = tf.matmul(fpropFactor, fpropFactor,
                          transpose_a=True) / tf.cast(B, tf.float32)
          updateOps.append(cov)
          statsUpdates[stats_var] = cov
          if opType != 'Conv2D':
            # HACK: for convolution we recompute fprop stats for
            # every layer including forking layers
            statsUpdates_cache[stats_var] = cov

      for stats_var in bpropStats_vars:
        stats_var_dim = int(stats_var.get_shape()[0])
        if stats_var not in statsUpdates_cache:
          old_bpropFactor = bpropFactor
          bpropFactor_shape = bpropFactor.get_shape()
          B = tf.shape(bpropFactor)[0]  # batch size
          C = int(bpropFactor_shape[-1])  # num channels
          if opType == 'Conv2D' or len(bpropFactor_shape) == 4:
            # define Oh/Ow locally rather than relying on the fprop stats
            # loop above having populated them
            Oh = int(bpropFactor_shape[1])
            Ow = int(bpropFactor_shape[2])
            if fpropFactor is not None:
              if self._approxT2:
                if KFAC_DEBUG:
                  print(('approxT2 grad fisher for %s' % (var.name)))
                bpropFactor = tf.reduce_sum(
                  bpropFactor, [1, 2])  # T^2 terms * 1/T^2
              else:
                bpropFactor = tf.reshape(
                  bpropFactor, [-1, C]) * Oh * Ow  # T * 1/T terms
            else:
              # just doing block diag approx. spatial independent
              # structure does not apply here. summing over
              # spatial locations
              if KFAC_DEBUG:
                print(('block diag approx fisher for %s' % (var.name)))
              bpropFactor = tf.reduce_sum(bpropFactor, [1, 2])

          # assume the sampled loss is averaged. TO-DO: figure out a better
          # way to handle this
          bpropFactor *= tf.to_float(B)
          ##

          cov_b = tf.matmul(
            bpropFactor, bpropFactor, transpose_a=True) / tf.to_float(tf.shape(bpropFactor)[0])

          updateOps.append(cov_b)
          statsUpdates[stats_var] = cov_b
          statsUpdates_cache[stats_var] = cov_b

    if KFAC_DEBUG:
      aKey = list(statsUpdates.keys())[0]
      statsUpdates[aKey] = tf.Print(statsUpdates[aKey],
                                    [tf.convert_to_tensor('step:'),
                                     self.global_step,
                                     tf.convert_to_tensor(
                                       'computing stats'),
                                     ])
    self.statsUpdates = statsUpdates
    return statsUpdates

  def apply_stats(self, statsUpdates):
    """ compute stats and update/apply the new stats to the running average
    """

    def updateAccumStats():
      if self._full_stats_init:
        return tf.cond(tf.greater(self.sgd_step, self._cold_iter), lambda: tf.group(*self._apply_stats(statsUpdates, accumulate=True, accumulateCoeff=1. / self._stats_accum_iter)), tf.no_op)
      else:
        return tf.group(*self._apply_stats(statsUpdates, accumulate=True, accumulateCoeff=1. / self._stats_accum_iter))

    def updateRunningAvgStats(statsUpdates, fac_iter=1):
      # return tf.cond(tf.greater_equal(self.factor_step,
      # tf.convert_to_tensor(fac_iter)), lambda:
      # tf.group(*self._apply_stats(stats_list, varlist)), tf.no_op)
      return tf.group(*self._apply_stats(statsUpdates))

    if self._async_stats:
      # asynchronous stats update
      update_stats = self._apply_stats(statsUpdates)

      queue = tf.FIFOQueue(1, [item.dtype for item in update_stats], shapes=[
        item.get_shape() for item in update_stats])
      enqueue_op = queue.enqueue(update_stats)

      def dequeue_stats_op():
        return queue.dequeue()
      self.qr_stats = tf.train.QueueRunner(queue, [enqueue_op])
      update_stats_op = tf.cond(tf.equal(queue.size(), tf.convert_to_tensor(
        0)), tf.no_op, lambda: tf.group(*[dequeue_stats_op(), ]))
    else:
      # synchronous stats update
      update_stats_op = tf.cond(tf.greater_equal(
        self.stats_step, self._stats_accum_iter), lambda: updateRunningAvgStats(statsUpdates), updateAccumStats)
    self._update_stats_op = update_stats_op
    return update_stats_op

  def _apply_stats(self, statsUpdates, accumulate=False, accumulateCoeff=0.):
    updateOps = []
    # obtain the stats var list
    for stats_var in statsUpdates:
      stats_new = statsUpdates[stats_var]
      if accumulate:
        # simple superbatch averaging
        update_op = tf.assign_add(
          stats_var, accumulateCoeff * stats_new, use_locking=True)
      else:
        # exponential running averaging
        update_op = tf.assign(
          stats_var, stats_var * self._stats_decay, use_locking=True)
        update_op = tf.assign_add(
          update_op, (1. - self._stats_decay) * stats_new, use_locking=True)
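        # the chained assigns implement S <- decay * S + (1 - decay) * S_new;
        # the second assign_add consumes the first assign's output tensor, so
        # the two writes are ordered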
      updateOps.append(update_op)

    with tf.control_dependencies(updateOps):
      stats_step_op = tf.assign_add(self.stats_step, 1)

    if KFAC_DEBUG:
      stats_step_op = (tf.Print(stats_step_op,
                                [tf.convert_to_tensor('step:'),
                                 self.global_step,
                                 tf.convert_to_tensor('fac step:'),
                                 self.factor_step,
                                 tf.convert_to_tensor('sgd step:'),
                                 self.sgd_step,
                                 tf.convert_to_tensor('Accum:'),
                                 tf.convert_to_tensor(accumulate),
                                 tf.convert_to_tensor('Accum coeff:'),
                                 tf.convert_to_tensor(accumulateCoeff),
                                 tf.convert_to_tensor('stat step:'),
                                 self.stats_step, updateOps[0], updateOps[1]]))
    return [stats_step_op, ]

  def getStatsEigen(self, stats=None):
    if len(self.stats_eigen) == 0:
      stats_eigen = {}
      if stats is None:
        stats = self.stats

      tmpEigenCache = {}
      with tf.device('/cpu:0'):
        for var in stats:
          for key in ['fprop_concat_stats', 'bprop_concat_stats']:
            for stats_var in stats[var][key]:
              if stats_var not in tmpEigenCache:
                stats_dim = stats_var.get_shape()[1].value
                e = tf.Variable(tf.ones(
                  [stats_dim]), name='KFAC_FAC/' + stats_var.name.split(':')[0] + '/e', trainable=False)
                Q = tf.Variable(tf.diag(tf.ones(
                  [stats_dim])), name='KFAC_FAC/' + stats_var.name.split(':')[0] + '/Q', trainable=False)
                stats_eigen[stats_var] = {'e': e, 'Q': Q}
                tmpEigenCache[
                  stats_var] = stats_eigen[stats_var]
              else:
                stats_eigen[stats_var] = tmpEigenCache[
                  stats_var]
      self.stats_eigen = stats_eigen
    return self.stats_eigen

  def computeStatsEigen(self):
    """ compute the eigen decomp using copied var stats to avoid concurrent read/write from other queue """
    # TO-DO: figure out why this op has delays (possibly moving
    # eigenvectors around?)
    with tf.device('/cpu:0'):
      def removeNone(tensor_list):
        local_list = []
        for item in tensor_list:
          if item is not None:
            local_list.append(item)
        return local_list

      def copyStats(var_list):
        print("copying stats to buffer tensors before eigen decomp")
        redundant_stats = {}
        copied_list = []
        for item in var_list:
          if item is not None:
            if item not in redundant_stats:
              if self._use_float64:
                redundant_stats[item] = tf.cast(
                  tf.identity(item), tf.float64)
              else:
                redundant_stats[item] = tf.identity(item)
            copied_list.append(redundant_stats[item])
          else:
            copied_list.append(None)
        return copied_list
      #stats = [copyStats(self.fStats), copyStats(self.bStats)]
      #stats = [self.fStats, self.bStats]

      stats_eigen = self.stats_eigen
      computedEigen = {}
      eigen_reverse_lookup = {}
      updateOps = []
      # sync copied stats
      # with tf.control_dependencies(removeNone(stats[0]) +
      # removeNone(stats[1])):
      with tf.control_dependencies([]):
        for stats_var in stats_eigen:
          if stats_var not in computedEigen:
            eigens = tf.self_adjoint_eig(stats_var)
            e = eigens[0]
            Q = eigens[1]
            if self._use_float64:
              e = tf.cast(e, tf.float32)
              Q = tf.cast(Q, tf.float32)
            updateOps.append(e)
            updateOps.append(Q)
            computedEigen[stats_var] = {'e': e, 'Q': Q}
            eigen_reverse_lookup[e] = stats_eigen[stats_var]['e']
            eigen_reverse_lookup[Q] = stats_eigen[stats_var]['Q']

      self.eigen_reverse_lookup = eigen_reverse_lookup
      self.eigen_update_list = updateOps

      if KFAC_DEBUG:
        # copy the list before appending the Print op so the reverse lookup
        # in applyStatsEigen still lines up with the eigen tensors
        self.eigen_update_list = [item for item in updateOps]
        with tf.control_dependencies(updateOps):
          updateOps.append(tf.Print(tf.constant(
            0.), [tf.convert_to_tensor('computed factor eigen')]))

    return updateOps

  def applyStatsEigen(self, eigen_list):
    updateOps = []
    print(('updating %d eigenvalue/vectors' % len(eigen_list)))
    for i, (tensor, mark) in enumerate(zip(eigen_list, self.eigen_update_list)):
      stats_eigen_var = self.eigen_reverse_lookup[mark]
      updateOps.append(
        tf.assign(stats_eigen_var, tensor, use_locking=True))

    with tf.control_dependencies(updateOps):
      factor_step_op = tf.assign_add(self.factor_step, 1)
      updateOps.append(factor_step_op)
      if KFAC_DEBUG:
        updateOps.append(tf.Print(tf.constant(
          0.), [tf.convert_to_tensor('updated kfac factors')]))
    return updateOps

  def getKfacPrecondUpdates(self, gradlist, varlist):
    updatelist = []
    vg = 0.

    assert len(self.stats) > 0
    assert len(self.stats_eigen) > 0
    assert len(self.factors) > 0
    counter = 0

    grad_dict = {var: grad for grad, var in zip(gradlist, varlist)}

    for grad, var in zip(gradlist, varlist):
      GRAD_RESHAPE = False

      fpropFactoredFishers = self.stats[var]['fprop_concat_stats']
      bpropFactoredFishers = self.stats[var]['bprop_concat_stats']

      if (len(fpropFactoredFishers) + len(bpropFactoredFishers)) > 0:
        counter += 1
        GRAD_SHAPE = grad.get_shape()
        if len(grad.get_shape()) > 2:
          # reshape conv kernel parameters
          KH = int(grad.get_shape()[0])
          KW = int(grad.get_shape()[1])
          C = int(grad.get_shape()[2])
          D = int(grad.get_shape()[3])

          if len(fpropFactoredFishers) > 1 and self._channel_fac:
            # reshape conv kernel parameters into a 3D tensor
            grad = tf.reshape(grad, [KH * KW, C, D])
          else:
            # reshape conv kernel parameters into 2D grad
            grad = tf.reshape(grad, [-1, D])
          GRAD_RESHAPE = True
        elif len(grad.get_shape()) == 1:
          # reshape bias or 1D parameters
          D = int(grad.get_shape()[0])

          grad = tf.expand_dims(grad, 0)
          GRAD_RESHAPE = True
        else:
          # 2D parameters
          C = int(grad.get_shape()[0])
          D = int(grad.get_shape()[1])

        if (self.stats[var]['assnBias'] is not None) and not self._blockdiag_bias:
          # homogeneous coordinates only work for 2D grads.
          # TO-DO: figure out how to factorize bias grad
          # stack bias grad
          var_assnBias = self.stats[var]['assnBias']
          grad = tf.concat(
            [grad, tf.expand_dims(grad_dict[var_assnBias], 0)], 0)

        # project gradient to eigen space and reshape the eigenvalues
        # for broadcasting
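        # with F ~= A (x) B and eigendecompositions A = Q_a L_a Q_a^T,
        # B = Q_b L_b Q_b^T, the preconditioned gradient is
        # Q_a ((Q_a^T G Q_b) / (l_a l_b^T + damping)) Q_b^T; the two loops
        # below apply the Q^T projections and collect the eigenvalues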
        eigVals = []

        for idx, stats in enumerate(self.stats[var]['fprop_concat_stats']):
          Q = self.stats_eigen[stats]['Q']
          e = detectMinVal(self.stats_eigen[stats][
                             'e'], var, name='act', debug=KFAC_DEBUG)

          Q, e = factorReshape(Q, e, grad, facIndx=idx, ftype='act')
          eigVals.append(e)
          grad = gmatmul(Q, grad, transpose_a=True, reduce_dim=idx)

        for idx, stats in enumerate(self.stats[var]['bprop_concat_stats']):
          Q = self.stats_eigen[stats]['Q']
          e = detectMinVal(self.stats_eigen[stats][
                             'e'], var, name='grad', debug=KFAC_DEBUG)

          Q, e = factorReshape(Q, e, grad, facIndx=idx, ftype='grad')
          eigVals.append(e)
          grad = gmatmul(grad, Q, transpose_b=False, reduce_dim=idx)
        ##

        #####
        # whiten using eigenvalues
        weightDecayCoeff = 0.
        if var in self._weight_decay_dict:
          weightDecayCoeff = self._weight_decay_dict[var]
          if KFAC_DEBUG:
            print(('weight decay coeff for %s is %f' % (var.name, weightDecayCoeff)))

        if self._factored_damping:
          if KFAC_DEBUG:
            print(('use factored damping for %s' % (var.name)))
          coeffs = 1.
          num_factors = len(eigVals)
          # compute the ratio of the trace norms of the left and right
          # KFAC factors, and its generalization to more factors
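          # for two factors this gives (l_a + pi * d)(l_b + d / pi) with
          # d = sqrt(epsilon + weight decay) and pi = sqrt(tnorm(A)/tnorm(B)),
          # i.e. the factored Tikhonov damping of Martens & Grosse (2015)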
          if len(eigVals) == 1:
            damping = self._epsilon + weightDecayCoeff
          else:
            damping = tf.pow(
              self._epsilon + weightDecayCoeff, 1. / num_factors)
          eigVals_tnorm_avg = [tf.reduce_mean(
            tf.abs(e)) for e in eigVals]
          for e, e_tnorm in zip(eigVals, eigVals_tnorm_avg):
            eig_tnorm_negList = [
              item for item in eigVals_tnorm_avg if item is not e_tnorm]
            if len(eigVals) == 1:
              adjustment = 1.
            elif len(eigVals) == 2:
              adjustment = tf.sqrt(
                e_tnorm / eig_tnorm_negList[0])
            else:
              eig_tnorm_negList_prod = reduce(
                lambda x, y: x * y, eig_tnorm_negList)
              adjustment = tf.pow(
                tf.pow(e_tnorm, num_factors - 1.) / eig_tnorm_negList_prod, 1. / num_factors)
            coeffs *= (e + adjustment * damping)
        else:
          coeffs = 1.
          damping = (self._epsilon + weightDecayCoeff)
          for e in eigVals:
            coeffs *= e
          coeffs += damping

        #grad = tf.Print(grad, [tf.convert_to_tensor('1'), tf.convert_to_tensor(var.name), grad.get_shape()])

        grad /= coeffs

        #grad = tf.Print(grad, [tf.convert_to_tensor('2'), tf.convert_to_tensor(var.name), grad.get_shape()])
        #####
        # project gradient back to euclidean space
        for idx, stats in enumerate(self.stats[var]['fprop_concat_stats']):
          Q = self.stats_eigen[stats]['Q']
          grad = gmatmul(Q, grad, transpose_a=False, reduce_dim=idx)

        for idx, stats in enumerate(self.stats[var]['bprop_concat_stats']):
          Q = self.stats_eigen[stats]['Q']
          grad = gmatmul(grad, Q, transpose_b=True, reduce_dim=idx)
        ##

        #grad = tf.Print(grad, [tf.convert_to_tensor('3'), tf.convert_to_tensor(var.name), grad.get_shape()])
        if (self.stats[var]['assnBias'] is not None) and not self._blockdiag_bias:
          # homogeneous coordinates only work for 2D grads.
          # TO-DO: figure out how to factorize bias grad
          # un-stack bias grad
          var_assnBias = self.stats[var]['assnBias']
          C_plus_one = int(grad.get_shape()[0])
          grad_assnBias = tf.reshape(tf.slice(grad,
                                              begin=[
                                                C_plus_one - 1, 0],
                                              size=[1, -1]), var_assnBias.get_shape())
          grad_assnWeights = tf.slice(grad,
                                      begin=[0, 0],
                                      size=[C_plus_one - 1, -1])
          grad_dict[var_assnBias] = grad_assnBias
          grad = grad_assnWeights

        #grad = tf.Print(grad, [tf.convert_to_tensor('4'), tf.convert_to_tensor(var.name), grad.get_shape()])
        if GRAD_RESHAPE:
          grad = tf.reshape(grad, GRAD_SHAPE)

        grad_dict[var] = grad

    print(('projecting %d gradient matrices' % counter))

    for g, var in zip(gradlist, varlist):
      grad = grad_dict[var]
      ### clipping ###
      if KFAC_DEBUG:
        print(('apply clipping to %s' % (var.name)))
        # route grad through the Print op, otherwise the norm never prints
        grad = tf.Print(grad, [tf.sqrt(tf.reduce_sum(tf.pow(grad, 2)))],
                        "Euclidean norm of new grad")
      local_vg = tf.reduce_sum(grad * g * (self._lr * self._lr))
      vg += local_vg

    # rescale everything
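    # vg approximates lr^2 * g^T F^{-1} g, an estimate of the quadratic (KL)
    # change induced by the update; scaling by min(1, sqrt(clip_kl / vg))
    # keeps the step inside the trust region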
    if KFAC_DEBUG:
      print('apply vFv clipping')

    scaling = tf.minimum(1., tf.sqrt(self._clip_kl / vg))
    if KFAC_DEBUG:
      scaling = tf.Print(scaling, [tf.convert_to_tensor(
        'clip: '), scaling, tf.convert_to_tensor(' vFv: '), vg])
    with tf.control_dependencies([tf.assign(self.vFv, vg)]):
      updatelist = [grad_dict[var] for var in varlist]
      for i, item in enumerate(updatelist):
        updatelist[i] = scaling * item

    return updatelist

  def compute_gradients(self, loss, var_list=None):
    varlist = var_list
    if varlist is None:
      varlist = tf.trainable_variables()
    g = tf.gradients(loss, varlist)

    return list(zip(g, varlist))

  def apply_gradients_kfac(self, grads):
    g, varlist = list(zip(*grads))

    if len(self.stats_eigen) == 0:
      self.getStatsEigen()

    qr = None
    # launch eigen-decomp on a queue thread
    if self._async:
      print('Use async eigen decomp')
      # get a list of factor loading tensors
      factorOps_dummy = self.computeStatsEigen()

      # define a queue for the list of factor loading tensors
      queue = tf.FIFOQueue(1, [item.dtype for item in factorOps_dummy], shapes=[
        item.get_shape() for item in factorOps_dummy])
      enqueue_op = tf.cond(tf.logical_and(tf.equal(tf.mod(self.stats_step, self._kfac_update), tf.convert_to_tensor(
        0)), tf.greater_equal(self.stats_step, self._stats_accum_iter)), lambda: queue.enqueue(self.computeStatsEigen()), tf.no_op)

      def dequeue_op():
        return queue.dequeue()

      qr = tf.train.QueueRunner(queue, [enqueue_op])

    updateOps = []
    global_step_op = tf.assign_add(self.global_step, 1)
    updateOps.append(global_step_op)
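    # pipeline: bump the global step, refresh the running stats, periodically
    # (every kfac_update steps once stats have accumulated) recompute the
    # eigendecompositions, then apply the preconditioned update with momentum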

    with tf.control_dependencies([global_step_op]):

      # compute updates
      assert self._update_stats_op is not None
      updateOps.append(self._update_stats_op)
      dependency_list = []
      if not self._async:
        dependency_list.append(self._update_stats_op)

      with tf.control_dependencies(dependency_list):
        def no_op_wrapper():
          return tf.group(*[tf.assign_add(self.cold_step, 1)])

        if not self._async:
          # synchronous eigen-decomp updates
          updateFactorOps = tf.cond(tf.logical_and(tf.equal(tf.mod(self.stats_step, self._kfac_update),
                                                            tf.convert_to_tensor(0)),
                                                   tf.greater_equal(self.stats_step, self._stats_accum_iter)), lambda: tf.group(*self.applyStatsEigen(self.computeStatsEigen())), no_op_wrapper)
        else:
          # asynchronous eigen-decomp updates using queue
          updateFactorOps = tf.cond(tf.greater_equal(self.stats_step, self._stats_accum_iter),
                                    lambda: tf.cond(tf.equal(queue.size(), tf.convert_to_tensor(0)),
                                                    tf.no_op,

                                                    lambda: tf.group(
                                                      *self.applyStatsEigen(dequeue_op())),
                                                    ),
                                    no_op_wrapper)

        updateOps.append(updateFactorOps)

        with tf.control_dependencies([updateFactorOps]):
          def gradOp():
            return list(g)

          def getKfacGradOp():
            return self.getKfacPrecondUpdates(g, varlist)
          u = tf.cond(tf.greater(self.factor_step,
                                 tf.convert_to_tensor(0)), getKfacGradOp, gradOp)

          optim = tf.train.MomentumOptimizer(
            self._lr * (1. - self._momentum), self._momentum)
          #optim = tf.train.AdamOptimizer(self._lr, epsilon=0.01)

          def optimOp():
            def updateOptimOp():
              if self._full_stats_init:
                return tf.cond(tf.greater(self.factor_step, tf.convert_to_tensor(0)), lambda: optim.apply_gradients(list(zip(u, varlist))), tf.no_op)
              else:
                return optim.apply_gradients(list(zip(u, varlist)))
            if self._full_stats_init:
              return tf.cond(tf.greater_equal(self.stats_step, self._stats_accum_iter), updateOptimOp, tf.no_op)
            else:
              return tf.cond(tf.greater_equal(self.sgd_step, self._cold_iter), updateOptimOp, tf.no_op)
          updateOps.append(optimOp())

    return tf.group(*updateOps), qr

  def apply_gradients(self, grads):
    coldOptim = tf.train.MomentumOptimizer(
      self._cold_lr, self._momentum)

    def coldSGDstart():
      sgd_grads, sgd_var = zip(*grads)

      if self.max_grad_norm is not None:
        sgd_grads, sgd_grad_norm = tf.clip_by_global_norm(sgd_grads, self.max_grad_norm)

      sgd_grads = list(zip(sgd_grads, sgd_var))

      sgd_step_op = tf.assign_add(self.sgd_step, 1)
      coldOptim_op = coldOptim.apply_gradients(sgd_grads)
      if KFAC_DEBUG:
        with tf.control_dependencies([sgd_step_op, coldOptim_op]):
          sgd_step_op = tf.Print(
            sgd_step_op, [self.sgd_step, tf.convert_to_tensor('doing cold sgd step')])
      return tf.group(*[sgd_step_op, coldOptim_op])

    kfacOptim_op, qr = self.apply_gradients_kfac(grads)

    def warmKFACstart():
      return kfacOptim_op

    return tf.cond(tf.greater(self.sgd_step, self._cold_iter), warmKFACstart, coldSGDstart), qr

  def minimize(self, loss, loss_sampled, var_list=None):
    grads = self.compute_gradients(loss, var_list=var_list)
    # build the stats update op for its side effect: it stores
    # self._update_stats_op, which apply_gradients_kfac depends on
    self.compute_and_apply_stats(loss_sampled, var_list=var_list)
    return self.apply_gradients(grads)
